{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04251.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04251.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uGGtdskZ7pjE",
        "colab_type": "text"
      },
      "source": [
        "## 3.3 线性回归的简洁实现"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IMjglK8n8nWz",
        "colab_type": "text"
      },
      "source": [
        "### 3.3.1 生成数据集"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gTghY6CM8y5P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn \n",
        "import numpy as np \n",
        "# Seed both RNGs: the dataset below is sampled with NumPy, so seeding\n",
        "# torch alone would leave the generated data different on every run.\n",
        "torch.manual_seed(1)\n",
        "np.random.seed(1)\n",
        "\n",
        "# Use float32 as the default dtype for newly created floating-point tensors.\n",
        "torch.set_default_dtype(torch.float32)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HV1VIDxjYQUH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Ground-truth parameters of the linear model y = Xw + b.\n",
        "num_inputs = 2\n",
        "num_examples = 1000\n",
        "true_w = [2, -3.4]\n",
        "true_b = 4.2\n",
        "\n",
        "# Features: one row per example, drawn from a standard normal distribution.\n",
        "features = torch.tensor(\n",
        "    np.random.normal(0, 1, (num_examples, num_inputs)),\n",
        "    dtype=torch.float,\n",
        ")\n",
        "\n",
        "# Labels: the exact linear response plus Gaussian noise with std 0.01.\n",
        "labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b\n",
        "labels += torch.tensor(\n",
        "    np.random.normal(0, 0.01, size=labels.size()),\n",
        "    dtype=torch.float,\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hneiHIoegGL8",
        "colab_type": "text"
      },
      "source": [
        "### 3.3.2 读取数据集"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BmRqI2gkHKpE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch.utils.data as Data \n",
        "\n",
        "batch_size = 10\n",
        "\n",
        "# Pair every feature row with its label so both are indexed together.\n",
        "dataset = Data.TensorDataset(features, labels)\n",
        "\n",
        "# Iterate over shuffled mini-batches, loaded by 4 worker processes.\n",
        "data_iter = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "17GEWnGPNx_L",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "outputId": "1abc635c-1104-40cc-8b46-5cea73295116"
      },
      "source": [
        "# Peek at a single mini-batch to check what the loader yields.\n",
        "X, y = next(iter(data_iter))\n",
        "print(X, '\\n', y)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[-0.1649, -1.7607],\n",
            "        [ 1.6516, -0.2547],\n",
            "        [-0.4610,  0.5266],\n",
            "        [ 0.7176, -0.3589],\n",
            "        [ 0.1484, -1.2508],\n",
            "        [ 1.0260,  0.1790],\n",
            "        [-0.5323,  0.1450],\n",
            "        [-1.5016, -0.3048],\n",
            "        [ 1.0703,  1.8395],\n",
            "        [ 3.5733,  0.1565]]) \n",
            " tensor([ 9.8718,  8.3559,  1.4954,  6.8732,  8.7421,  5.6415,  2.6326,  2.2218,\n",
            "         0.0736, 10.8089])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DYUQgO28N95I",
        "colab_type": "text"
      },
      "source": [
        "### 3.3.3 定义模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7SI3EwmyOFu_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "5e942edc-6eeb-4a14-806b-006e66092926"
      },
      "source": [
        "class LinearNet(nn.Module):\n",
        "    \"\"\"Linear regression written as a module with one fully connected layer.\"\"\"\n",
        "\n",
        "    def __init__(self, n_feature):\n",
        "        super().__init__()\n",
        "        # n_feature inputs mapped to a single output value.\n",
        "        self.linear = nn.Linear(n_feature, 1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return the linear layer's output for input x.\"\"\"\n",
        "        return self.linear(x)\n",
        "\n",
        "\n",
        "net = LinearNet(num_inputs)\n",
        "print(net)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "LinearNet(\n",
            "  (linear): Linear(in_features=2, out_features=1, bias=True)\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kXGCoOOtO89y",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "e8b88579-a97a-4848-a17c-8225c57bf9ca"
      },
      "source": [
        "# The previous cell spelled the model out by hand; this Sequential\n",
        "# version replaces it and is the one used from here on.\n",
        "net = nn.Sequential(nn.Linear(num_inputs, 1))\n",
        "\n",
        "# Show the container, then its single layer accessed by index.\n",
        "print(net, net[0], sep='\\n')"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sequential(\n",
            "  (0): Linear(in_features=2, out_features=1, bias=True)\n",
            ")\n",
            "Linear(in_features=2, out_features=1, bias=True)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xevPxDGTP0-1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "aab0239c-35b7-455b-b452-5048683d32f7"
      },
      "source": [
        "# Display the network's parameters as they stand right after construction.\n",
        "print(*net.parameters(), sep='\\n')"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Parameter containing:\n",
            "tensor([[-0.3435, -0.1186]], requires_grad=True)\n",
            "Parameter containing:\n",
            "tensor([-0.5459], requires_grad=True)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UgI5OajnQFXv",
        "colab_type": "text"
      },
      "source": [
        "### 3.3.4 初始化模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5t2vmymEQMLU",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "effb2b66-36ba-4842-85fa-8beee4bc33f8"
      },
      "source": [
        "from torch.nn import init\n",
        "\n",
        "# Re-initialise the linear layer in place: weights ~ N(0, 0.01^2), bias = 0.\n",
        "linear_layer = net[0]\n",
        "init.normal_(linear_layer.weight, mean=0.0, std=0.01)\n",
        "# Kept as the final expression so the cell displays the updated bias.\n",
        "init.constant_(linear_layer.bias, val=0.0)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Parameter containing:\n",
              "tensor([0.], requires_grad=True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wlIthBCaQqGl",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "c20c45d4-e12f-41dd-82d9-eb7f0d096c2e"
      },
      "source": [
        "# Display the parameters again to confirm the custom initialisation took effect.\n",
        "print(*net.parameters(), sep='\\n')"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Parameter containing:\n",
            "tensor([[0.0207, 0.0033]], requires_grad=True)\n",
            "Parameter containing:\n",
            "tensor([0.], requires_grad=True)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7brJ7MPkQwY2",
        "colab_type": "text"
      },
      "source": [
        "### 3.3.5 定义损失函数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VRODG-UbQ2pp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Mean squared error, averaged over the batch (MSELoss's default reduction, made explicit).\n",
        "loss = nn.MSELoss(reduction='mean')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iBBwlhAZQ7Xw",
        "colab_type": "text"
      },
      "source": [
        "### 3.3.6 定义优化算法"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BNd-52WXRB92",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 151
        },
        "outputId": "0394b301-0d76-4993-b61c-d69ab28653d5"
      },
      "source": [
        "import torch.optim as optim  \n",
        "\n",
        "# Mini-batch SGD over every parameter of the network.\n",
        "optimizer = optim.SGD(params=net.parameters(), lr=0.03)\n",
        "print(optimizer)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "SGD (\n",
            "Parameter Group 0\n",
            "    dampening: 0\n",
            "    lr: 0.03\n",
            "    momentum: 0\n",
            "    nesterov: False\n",
            "    weight_decay: 0\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D-OkbOgnRSO8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "e72c78a3-dde7-4303-9342-a343b399127b"
      },
      "source": [
        "num_epochs = 3\n",
        "for epoch in range(1, num_epochs + 1):\n",
        "    for X, y in data_iter:\n",
        "        output = net(X)\n",
        "        # y is 1-D; reshape it to a column so it lines up with the (batch, 1) output.\n",
        "        batch_loss = loss(output, y.view(-1, 1))\n",
        "        optimizer.zero_grad()  # gradients accumulate across backward() calls, so reset first\n",
        "        batch_loss.backward()\n",
        "        optimizer.step()\n",
        "    # Report the loss over the whole dataset rather than the last mini-batch:\n",
        "    # a single batch of 10 samples is a noisy estimate, and the batch-loss\n",
        "    # variable would not even exist if the loader yielded nothing.\n",
        "    with torch.no_grad():\n",
        "        epoch_loss = loss(net(features), labels.view(-1, 1))\n",
        "    print('epoch %d, loss: %f' % (epoch, epoch_loss.item()))"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 1, loss: 0.000133\n",
            "epoch 2, loss: 0.000072\n",
            "epoch 3, loss: 0.000061\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VxcghBN_V8Wi",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "039e96fd-150e-478a-f09a-36b866ffdda4"
      },
      "source": [
        "# Compare the true weights with the learned ones (detached from autograd for display).\n",
        "dense = net[0]\n",
        "(true_w, dense.weight.detach())"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "([2, -3.4], tensor([[ 2.0001, -3.4009]]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kpTz7kiuWOqX",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "53270cd9-bf76-4366-aee5-e110d8083a91"
      },
      "source": [
        "# Compare the true bias with the learned one (detached from autograd for display).\n",
        "(true_b, dense.bias.detach())"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(4.2, tensor([4.1997]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    }
  ]
}